{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ca0ce66e-8c08-4c74-a5ae-60592173785d",
   "metadata": {},
   "source": [
    "# Using Simplicit's Easy API To Simulate Example Mesh\n",
    "[Simplicits](https://research.nvidia.com/labs/toronto-ai/simplicits/) is a mesh-free, representation-agnostic way to simulation elastic deformations. \n",
    "\n",
    "**In this notebook, we present a simple way to use the simplicit's code base.** We can create a simple object, train it, simulate it and visualize all in a very few lines of code via our `easy_api`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "006e80b7-ac7c-4a82-a065-406c01bb15f0",
   "metadata": {},
   "source": [
    "#### Requirements for this demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b05a6c16-2c2a-4a70-a785-dabfbee269b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q k3d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6dcc14f-a4a3-4c9c-87f8-5a28d6f5a2d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy, math, os, sys, logging, threading\n",
    "from typing import List, Tuple\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import kaolin as kal\n",
    "import warp as wp  \n",
    "\n",
    "from IPython.display import display\n",
    "\n",
    "import k3d\n",
    "import ipywidgets as widgets\n",
    "\n",
    "from ipywidgets import Button, HBox, VBox\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "#local logger, prints at info or above\n",
    "logging.basicConfig(level=logging.INFO, stream=sys.stdout)\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "#logger used in the api code\n",
    "logging.getLogger('kaolin.physics').setLevel(logging.INFO) # Prints everything at debug level or above\n",
    "\n",
    "sys.path.append(str(Path(\"..\")))\n",
    "from tutorial_common import COMMON_DATA_DIR\n",
    "\n",
    "def sample_mesh_path(fname):\n",
    "    return os.path.join(COMMON_DATA_DIR, 'meshes', fname)\n",
    "\n",
    "def print_tensor(t, name='', **kwargs):\n",
    "    print(kal.utils.testing.tensor_info(t, name=name, **kwargs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "492694e2-a66e-43f0-b9b3-4ed36a421e59",
   "metadata": {},
   "source": [
    "## Loading Geometry\n",
    "Simplicits works with any geometry: meshes, pointclouds, SDFs, Gaussian splats, and more. For this example, we will use a mesh:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40a49d68-c73e-4a06-9708-695e2f201318",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import and triangulate to enable rasterization; move to GPU\n",
    "mesh = kal.io.import_mesh(sample_mesh_path('fox.obj'), triangulate=True).cuda()\n",
    "mesh.vertices = kal.ops.pointcloud.center_points(mesh.vertices.unsqueeze(0), normalize=True).squeeze(0) \n",
    "orig_vertices = mesh.vertices.clone()  # Also save original undeformed vertices\n",
    "print(mesh)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6e54cb8-e2d2-4921-814b-62ef9a7b431e",
   "metadata": {},
   "source": [
    "## Sample Geometry\n",
    "To enable simulation we need point samples within the object's volume, and physical material parameters per point. Lets set this up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c11c603-4b6f-4229-9303-94c40c8d06b1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Physics material parameters\n",
    "soft_youngs_modulus = 1e5\n",
    "poisson_ratio = 0.45\n",
    "rho = 500.0  # kg/m^3\n",
    "approx_volume = 0.5  # m^3\n",
    "\n",
    "# Points sampled over the object's bounding box\n",
    "num_samples = 1000000\n",
    "uniform_pts = torch.rand(num_samples, 3, device='cuda') * (orig_vertices.max(dim=0).values - orig_vertices.min(dim=0).values) + orig_vertices.min(dim=0).values\n",
    "boolean_signs = kal.ops.mesh.check_sign(mesh.vertices.unsqueeze(0), mesh.faces, uniform_pts.unsqueeze(0), hash_resolution=512)\n",
    "\n",
    "# use pts within the object\n",
    "pts = uniform_pts[boolean_signs.squeeze()]\n",
    "yms = torch.full((pts.shape[0],), soft_youngs_modulus, device=\"cuda\")\n",
    "prs = torch.full((pts.shape[0],), poisson_ratio, device=\"cuda\")\n",
    "rhos = torch.full((pts.shape[0],), rho, device=\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0799e8c2-832d-48fe-8c80-16ce73f2a851",
   "metadata": {},
   "source": [
    "## Visualize Sample Points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5900b52a-8876-478f-bfb2-a2080a041dba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import k3d\n",
    "plot = k3d.plot()\n",
    "plot += k3d.points(pts.cpu().detach().numpy(), point_size=0.01)\n",
    "plot.display()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66146660-e708-4cf7-815c-a1fd5d337053",
   "metadata": {},
   "source": [
    "## Create and Train a SimplicitsObject\n",
    "We encapsulate everything Simplicits method needs to know about the simulated object in a `SimplicitsObject` instance. Once the object is created, we need to run training to learn reduced degrees of freedom our simulator can use. \n",
    "\n",
    "**This will take a couple of minutes.** Please be patient."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "803fb77f",
   "metadata": {},
   "source": [
    "### New method for Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "861e7e85-e160-4c12-8ef5-e1b771901dd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_obj = kal.physics.simplicits.SimplicitsObject.create_trained(\n",
    "    pts, # sampled points\n",
    "    yms, # material stiffness\n",
    "    prs, # material compressibility ratio\n",
    "    rhos, # material density\n",
    "    approx_volume, # volue\n",
    "    num_handles=5, # skinning handles (DOFs)\n",
    "    training_num_steps=10000, \n",
    "    training_lr_start=1e-3,\n",
    "    training_lr_end=1e-3,\n",
    "    training_le_coeff=1e-1,\n",
    "    training_lo_coeff=1e6,\n",
    "    training_log_every=1000,\n",
    "    normalize_for_training=True)\n",
    "\n",
    "# Optionally, you can load a pre-trained object by loading the skinning function\n",
    "# skinning_fcn = torch.load(\"path_to_fcn.pth\")\n",
    "# sim_obj = kal.physics_sparse.simplicits_common.SimplicitsObject.create_from_function(\n",
    "#     pts, yms, prs, rhos, approx_volume, skinning_fcn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e0bb9b8-4d0a-43cc-9465-2ea966c515e5",
   "metadata": {},
   "source": [
    "## Create a Scene\n",
    "Now we are ready to set up all the forces in the scene to simulated as well as simulation settings. For example, here we will add gravity and a floor plane."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71ad0028-1686-4d4c-a6b5-5e5731ff0612",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene = kal.physics.simplicits.SimplicitsScene()  # default empty scene\n",
    "#Convergence might not be guaranteed with few newton iterations, but runs very fast\n",
    "scene.max_newton_steps = 5\n",
    "scene.timestep = 0.03\n",
    "scene.direct_solve = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09cc0c0d-1e78-42c0-845d-9c6665667c61",
   "metadata": {},
   "source": [
    " The same `SimplicitsObject` can be added to multiple scene. Let's add it to our scene. Not we can reference it within the scene using `obj_idx`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b27b0dc8-7673-4297-9403-35334c44443d",
   "metadata": {},
   "outputs": [],
   "source": [
    "obj_idx = scene.add_object(sim_obj)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6935147a-1e0e-444b-a0b9-a50232e4bf73",
   "metadata": {},
   "source": [
    "Lets set gravity and floor forces on the scene"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7c5640e-1008-44d1-a9ae-caa1aad58576",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene.set_scene_gravity(acc_gravity=torch.tensor([0, 9.8, 0]))\n",
    "scene.set_scene_floor(floor_height=-0.8, floor_axis=1, floor_penalty=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81d90e2d",
   "metadata": {},
   "source": [
    "# Display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36b45eb6-6d0f-4443-90f5-bf805a4c4e53",
   "metadata": {},
   "source": [
    "## Set Up Rendering\n",
    "Let's set up rendering of our mesh so we can view it in a notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06cff49-da6d-4806-bebb-588df336e6ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "resolution = 512\n",
    "camera = kal.render.easy_render.default_camera(resolution).cuda()\n",
    "\n",
    "light_direction = kal.render.lighting.sg_direction_from_azimuth_elevation(1., 1.)\n",
    "lighting = kal.render.lighting.SgLightingParameters(amplitude=3., sharpness=5., direction=light_direction).cuda()\n",
    "    \n",
    "def render(in_cam):\n",
    "    # render\n",
    "    active_pass=kal.render.easy_render.RenderPass.render\n",
    "    render_res = kal.render.easy_render.render_mesh(in_cam, mesh, lighting=lighting)\n",
    "\n",
    "    # create white background\n",
    "    img = render_res[active_pass]\n",
    "    background_mask = (render_res[kal.render.easy_render.RenderPass.face_idx] < 0).bool()\n",
    "    img2 = torch.clamp(img, 0, 1)[0]\n",
    "    img2[background_mask[0]] = 1\n",
    "    final = (img2 * 255.).to(torch.uint8)\n",
    "    return {\"img\":final, \"face_idx\": render_res[kal.render.easy_render.RenderPass.face_idx].squeeze(0).unsqueeze(-1)}\n",
    "\n",
    "# faster low-res render during mouse motion\n",
    "def fast_render(in_cam, factor=8):\n",
    "    lowres_cam = copy.deepcopy(in_cam)\n",
    "    lowres_cam.width = in_cam.width // factor\n",
    "    lowres_cam.height = in_cam.height // factor\n",
    "    return render(lowres_cam)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef421ba7-e245-4780-88f5-8fe10c2bb43d",
   "metadata": {},
   "source": [
    "## That's it! Let's Run and View or Physics Simulation\n",
    "All we need to do now is run simulation and display the object using Kaolin's in-notebook visualizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30963bf0-06be-4186-b4e0-ab7d377b1409",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset mesh to its rest state\n",
    "mesh.vertices = orig_vertices\n",
    "\n",
    "\n",
    "global sim_thread_open, sim_thread\n",
    "sim_thread_open = False\n",
    "sim_thread = None\n",
    "\n",
    "def reset_simulation(visualizer):\n",
    "    global scene\n",
    "    with visualizer.out:\n",
    "        scene.reset_scene()\n",
    "    mesh.vertices = scene.get_object_deformed_pts(obj_idx, orig_vertices)\n",
    "    visualizer.render_update()\n",
    "    \n",
    "\n",
    "def run_sim():\n",
    "    for s in range(100):\n",
    "        with visualizer.out:\n",
    "            scene.run_sim_step()\n",
    "            print(\".\", end=\"\")\n",
    "        mesh.vertices = scene.get_object_deformed_pts(obj_idx, orig_vertices)\n",
    "\n",
    "        visualizer.render_update()\n",
    "\n",
    "def start_simulation(b):\n",
    "    global sim_thread_open, sim_thread\n",
    "    with visualizer.out:\n",
    "        if(sim_thread_open):\n",
    "            sim_thread.join()\n",
    "            sim_thread_open = False\n",
    "        sim_thread_open = True\n",
    "        sim_thread = threading.Thread(target=run_sim, daemon=True)\n",
    "        sim_thread.start()\n",
    "\n",
    "visualizer = kal.visualize.IpyTurntableVisualizer(\n",
    "    resolution, resolution, copy.deepcopy(camera), render, fast_render=fast_render,\n",
    "    max_fps=24, world_up_axis=1)\n",
    "\n",
    "buttons = [Button(description=x) for x in\n",
    "           ['Run Sim', 'Reset']]\n",
    "buttons[0].on_click(lambda e: start_simulation(e))\n",
    "buttons[1].on_click(lambda e: reset_simulation(visualizer))\n",
    "\n",
    "\n",
    "reset_simulation(visualizer)\n",
    "display(HBox([visualizer.canvas, VBox(buttons)]), visualizer.out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
